///
/// Copyright (C) 2016, Dependable Systems Laboratory, EPFL
/// Copyright (C) 2015-2016, Cyberhaven
///
/// Permission is hereby granted, free of charge, to any person obtaining a copy
/// of this software and associated documentation files (the "Software"), to deal
/// in the Software without restriction, including without limitation the rights
/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
/// copies of the Software, and to permit persons to whom the Software is
/// furnished to do so, subject to the following conditions:
///
/// The above copyright notice and this permission notice shall be included in all
/// copies or substantial portions of the Software.
///
/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
/// SOFTWARE.
///

#include <s2e/ConfigFile.h>
#include <s2e/Plugins/VulnerabilityAnalysis/Recipe/Recipe.h>
#include <s2e/S2E.h>
#include <s2e/Utils.h>

#include <klee/util/ExprUtil.h>

#include "PovGenerationPolicy.h"

using namespace klee;

namespace s2e {
namespace plugins {

// Registers the PovGenerationPolicy plugin with S2E. The trailing strings name the
// Recipe, ProcessExecutionDetector and PovGenerator plugins, all of which initialize()
// looks up and later uses without null checks.
// NOTE(review): presumably these trailing names are the plugin's declared dependencies
// and the empty string is an unused function/config name -- confirm against the
// S2E_DEFINE_PLUGIN macro definition.
S2E_DEFINE_PLUGIN(PovGenerationPolicy, "PovGenerationPolicy S2E plugin", "", "Recipe", "ProcessExecutionDetector",
                  "PovGenerator");

void PovGenerationPolicy::initialize() {

    // Getting plugin configuration
    ConfigFile *cfg = s2e()->getConfig();
    m_maxPovCount = cfg->getInt(getConfigKey() + ".maxPovCount", 5);
    m_maxCrashCount = cfg->getInt(getConfigKey() + ".maxCrashCount", 5);
    m_crashCount = 0;

    // Detecting vulnerability sources
    m_process = s2e()->getPlugin<ProcessExecutionDetector>();
    m_recipe = s2e()->getPlugin<recipe::Recipe>();
    m_recipe->onPovReady.connect(sigc::bind(sigc::mem_fun(*this, &PovGenerationPolicy::onPovReadyHandler), false));

    s2e()->getCorePlugin()->onSymbolicAddress.connect(sigc::mem_fun(*this, &PovGenerationPolicy::onSymbolicAddress),
                                                      fsigc::signal_base::HIGH_PRIORITY);

    // Detecting platform
    m_decreeMonitor = s2e()->getPlugin<DecreeMonitor>();
    m_linuxMonitor = s2e()->getPlugin<LinuxMonitor>();
    m_windowsCrashMonitor = s2e()->getPlugin<WindowsCrashMonitor>();

    if (m_decreeMonitor) {
        m_decreeMonitor->onSegFault.connect(sigc::mem_fun(*this, &PovGenerationPolicy::onSegFault));
    } else if (m_linuxMonitor) {
        m_linuxMonitor->onSegFault.connect(sigc::mem_fun(*this, &PovGenerationPolicy::onSegFault));
    } else if (m_windowsCrashMonitor) {
        m_windowsCrashMonitor->onUserModeCrash.connect(sigc::mem_fun(*this, &PovGenerationPolicy::onSegFaultWinUser));
        m_windowsCrashMonitor->onKernelModeCrash.connect(
            sigc::mem_fun(*this, &PovGenerationPolicy::onSegFaultWinKernel));
    }

    m_povGenerator = static_cast<pov::PovGenerator *>(s2e()->getPlugin("PovGenerator"));

    // The test case generator if present will not generate proper PoVs and may flood
    // with incorrect test cases, so we disable it here.
    auto tcgen = s2e()->getPlugin<testcases::TestCaseGenerator>();
    if (tcgen) {
        tcgen->disable();
    }
}

// Kills a state that is about to fork on a symbolic program counter when a PoV has
// already been recorded for the instruction at the state's current PC.
//
// Only events raised for a tracked PC with reason == PC are considered; every other
// event is ignored. virtualAddress, concreteAddress and concretize are not used, so the
// concretization decision is left untouched.
void PovGenerationPolicy::onSymbolicAddress(S2EExecutionState *state, ref<Expr> virtualAddress,
                                            uint64_t concreteAddress, bool &concretize,
                                            CorePlugin::symbolicAddressReason reason) {
    const auto currentPc = state->regs()->getPc();

    if (!m_process->isTrackedPc(state, currentPc, true)) {
        return;
    }

    if (reason != CorePlugin::symbolicAddressReason::PC) {
        return;
    }

    // Keys are (fault address, PoV type, recipe name) tuples, so several entries may
    // share the same address. One match is enough: stop right after terminating so the
    // state is never terminated (or inspected) again through a second matching entry.
    for (const auto &entry : m_uniquePovMap) {
        if (std::get<0>(entry.first) == currentPc) {
            s2e()->getExecutor()->terminateState(*state, "Killing state because that PC has already generated a PoV");
            return;
        }
    }
}

// Generates a PoV for the given state unless the limit for its
// (fault address, PoV type, recipe name) key has already been reached.
//
// state      - execution state to generate the PoV from
// opt        - PoV options; m_faultAddress and m_type are part of the uniqueness key
// recipeName - name of the originating recipe, empty when there is none
// isCrash    - selects the "crash" file prefix (when recipeName is empty) and the
//              CRASH/POV kind passed to the onPovReady signal
//
// The per-key counter is incremented before generation is attempted, so failed
// attempts count towards the limit as well.
void PovGenerationPolicy::onPovReadyHandler(S2EExecutionState *state, const PovOptions &opt,
                                            const std::string &recipeName, bool isCrash) {
    getInfoStream(state) << "Generating PoV type " << opt.m_type << " at " << hexval(opt.m_faultAddress)
                         << " from recipe '" << recipeName << "'\n";

    // operator[] creates a zero-initialized counter the first time a key is seen.
    const UniquePovKey key = std::make_tuple(opt.m_faultAddress, opt.m_type, recipeName);
    auto &generatedSoFar = m_uniquePovMap[key];
    if (generatedSoFar >= m_maxPovCount) {
        getDebugStream(state) << "PoV limit reached\n";
        return;
    }
    generatedSoFar++;

    // File name prefix: "recipe-<name>" for recipes, "crash" for plain crashes,
    // empty otherwise.
    std::string prefix;
    if (!recipeName.empty()) {
        prefix = "recipe-" + recipeName;
    } else if (isCrash) {
        prefix = "crash";
    }

    std::vector<std::string> filePaths;
    const bool generated = m_povGenerator->generatePoV(state, opt, prefix, filePaths);
    if (!generated) {
        getWarningsStream(state) << "Failed to generate PoV\n";
        return;
    }

    onPovReady.emit(state, opt, recipeName, filePaths, isCrash ? CRASH : POV);
}

// Adapter for Windows user-mode crashes: forwards the crashing process id and the
// exception address to the common onSegFault() handler.
void PovGenerationPolicy::onSegFaultWinUser(S2EExecutionState *state, const WindowsUserModeCrash &crash) {
    const auto crashedPid = crash.Pid;
    const auto faultingPc = crash.ExceptionAddress;
    onSegFault(state, crashedPid, faultingPc);
}

// Adapter for Windows kernel-mode crashes. The bug check description is currently
// ignored; the common onSegFault() handler is invoked with a zero pid and a zero pc.
void PovGenerationPolicy::onSegFaultWinKernel(S2EExecutionState *state,
                                              const vmi::windows::BugCheckDescription &crash) {
    // TODO: implement this properly, we don't have an easy way
    // to get the required information for now.
    const uint64_t placeholderPid = 0;
    const uint64_t placeholderPc = 0;
    onSegFault(state, placeholderPid, placeholderPc);
}

// Common crash handler: turns a reported crash into a "crash" PoV, up to
// m_maxCrashCount times.
//
// state - execution state in which the crash was reported
// pid   - id of the crashing process; it must be tracked by the process detector,
//         otherwise diagnostics are written to the warnings stream and s2e_assert fires
// pc    - program counter reported for the crash, used as the PoV fault address
void PovGenerationPolicy::onSegFault(S2EExecutionState *state, uint64_t pid, uint64_t pc) {
    if (!m_process->isTracked(state, pid)) {
        // Only dump memory on decree, because it has relatively small binaries
        if (m_decreeMonitor) {
            std::stringstream ss;
            m_decreeMonitor->dumpUserspaceMemory(state, ss);
            getWarningsStream(state) << ss.str();
        }

        state->disassemble(getWarningsStream(state) << "\n", pc, 64);

        s2e_assert(state, false, "Untracked pid=" << hexval(pid) << " segfaulted");
    }

    // TODO: might not be needed once we can figure out the real instruction of the crash
    // Use >= so that at most m_maxCrashCount crashes are processed (a strict > would let
    // one extra crash through), matching the limit check in onPovReadyHandler().
    if (m_crashCount >= m_maxCrashCount) {
        getDebugStream(state) << "Reached max crash count limit\n";
        return;
    }

    ++m_crashCount;

    // XXX: it's more useful to report the address of the last instructions
    // Only the fault address is set here; every other option keeps whatever value
    // PovOptions' default construction gives it.
    PovOptions opts;
    opts.m_faultAddress = pc;
    onPovReadyHandler(state, opts, "", true);
}

} // namespace plugins
} // namespace s2e
